{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id='1'></a>\n",
    "# Import packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "from keras.layers import *\n",
    "import keras.backend as K\n",
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import cv2\n",
    "import glob\n",
    "import time\n",
    "import numpy as np\n",
    "from pathlib import PurePath, Path\n",
    "from IPython.display import clear_output\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id='4'></a>\n",
    "# Config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "#K.set_learning_phase(1)\n",
    "#K.set_learning_phase(0) # set to 0 in inference phase"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of CPU cores\n",
    "num_cpus = os.cpu_count()\n",
    "\n",
    "# Input/Output resolution\n",
    "RESOLUTION = 64 # 64x64, 128x128, 256x256\n",
    "assert (RESOLUTION % 64) == 0, \"RESOLUTION should be 64, 128, or 256.\"\n",
    "\n",
    "# Batch size\n",
    "batchSize = 8\n",
    "assert (batchSize != 1 and batchSize % 2 == 0) , \"batchSize should be an even number.\"\n",
    "\n",
    "# Use motion blurs (data augmentation)\n",
    "# set True if training data contains images extracted from videos\n",
    "use_da_motion_blur = False \n",
    "\n",
    "# Use eye-aware training\n",
    "# require images generated from prep_binary_masks.ipynb\n",
    "use_bm_eyes = True\n",
    "\n",
    "# Probability of random color matching (data augmentation)\n",
    "prob_random_color_match = 0.5\n",
    "\n",
    "da_config = {\n",
    "    \"prob_random_color_match\": prob_random_color_match,\n",
    "    \"use_da_motion_blur\": use_da_motion_blur,\n",
    "    \"use_bm_eyes\": use_bm_eyes\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path to training images\n",
    "img_dirA = './faceA'\n",
    "img_dirB = './faceB'\n",
    "img_dirA_bm_eyes = \"./binary_masks/faceA_eyes\"\n",
    "img_dirB_bm_eyes = \"./binary_masks/faceB_eyes\"\n",
    "\n",
    "# Path to saved model weights\n",
    "models_dir = \"./models\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Architecture configuration\n",
    "arch_config = {}\n",
    "arch_config['IMAGE_SHAPE'] = (RESOLUTION, RESOLUTION, 3)\n",
    "arch_config['use_self_attn'] = True\n",
    "arch_config['norm'] = \"instancenorm\" # instancenorm, batchnorm, layernorm, groupnorm, none\n",
    "arch_config['model_capacity'] = \"standard\" # standard, lite"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loss function weights configuration\n",
    "loss_weights = {}\n",
    "loss_weights['w_D'] = 0.1 # Discriminator\n",
    "loss_weights['w_recon'] = 1. # L1 reconstruction loss\n",
    "loss_weights['w_edge'] = 0.1 # edge loss\n",
    "loss_weights['w_eyes'] = 30. # reconstruction and edge loss on eyes area\n",
    "loss_weights['w_pl'] = (0.01, 0.1, 0.3, 0.1) # perceptual loss (0.003, 0.03, 0.3, 0.3)\n",
    "\n",
    "# Init. loss config.\n",
    "loss_config = {}\n",
    "loss_config[\"gan_training\"] = \"mixup_LSGAN\" # \"mixup_LSGAN\" or \"relativistic_avg_LSGAN\"\n",
    "loss_config['use_PL'] = False\n",
    "loss_config[\"PL_before_activ\"] = False\n",
    "loss_config['use_mask_hinge_loss'] = False\n",
    "loss_config['m_mask'] = 0.\n",
    "loss_config['lr_factor'] = 1.\n",
    "loss_config['use_cyclic_loss'] = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id='5'></a>\n",
    "# Define models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from networks.faceswap_gan_model import FaceswapGANModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "model = FaceswapGANModel(**arch_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id='6'></a>\n",
    "# Load Model Weights\n",
    "\n",
    "Weights file names:\n",
    "```python\n",
    "encoder.h5, decoder_A.h5, deocder_B.h5, netDA.h5, netDB.h5\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model weights files are successfully loaded\n"
     ]
    }
   ],
   "source": [
    "model.load_weights(path=models_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "### The following cells are for training, skip to [transform_face()](#tf) for inference.\n",
    "\n",
    "# Define Losses and Build Training Functions\n",
    "\n",
    "TODO: split into two methods"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If it throws errors building vggface ResNet (probably due to Keras version), the following code is what we did to make it runnable on Colab.\n",
    "\n",
    "```python\n",
    "!wget https://github.com/rcmalli/keras-vggface/releases/download/v2.0/rcmalli_vggface_tf_notop_resnet50.h5\n",
    "from colab_demo.vggface_models import RESNET50\n",
    "vggface = RESNET50(include_top=False, weights=None, input_shape=(224, 224, 3))\n",
    "vggface.load_weights(\"rcmalli_vggface_tf_notop_resnet50.h5\")\n",
    "\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# https://github.com/rcmalli/keras-vggface\n",
    "#!pip install keras_vggface --no-dependencies\n",
    "from keras_vggface.vggface import VGGFace\n",
    "\n",
    "# VGGFace ResNet50\n",
    "vggface = VGGFace(include_top=False, model='resnet50', input_shape=(224, 224, 3))\n",
    "\n",
    "#vggface.summary()\n",
    "\n",
    "model.build_pl_model(vggface_model=vggface, before_activ=loss_config[\"PL_before_activ\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.build_train_functions(loss_weights=loss_weights, **loss_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "<a id='9'></a>\n",
    "# DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from data_loader.data_loader import DataLoader"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Visualizer\n",
    "\n",
    "TODO: write a Visualizer class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import showG, showG_mask, showG_eyes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id='10'></a>\n",
    "# Start Training\n",
    "TODO: make training script compact"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create ./models directory\n",
    "Path(f\"models\").mkdir(parents=True, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of images in folder A: 376\n",
      "Number of images in folder B: 318\n"
     ]
    }
   ],
   "source": [
    "# Get filenames\n",
    "train_A = glob.glob(img_dirA+\"/*.*\")\n",
    "train_B = glob.glob(img_dirB+\"/*.*\")\n",
    "\n",
    "train_AnB = train_A + train_B\n",
    "\n",
    "assert len(train_A), \"No image found in \" + str(img_dirA)\n",
    "assert len(train_B), \"No image found in \" + str(img_dirB)\n",
    "print (\"Number of images in folder A: \" + str(len(train_A)))\n",
    "print (\"Number of images in folder B: \" + str(len(train_B)))\n",
    "\n",
    "if use_bm_eyes:\n",
    "    assert len(glob.glob(img_dirA_bm_eyes+\"/*.*\")), \"No binary mask found in \" + str(img_dirA_bm_eyes)\n",
    "    assert len(glob.glob(img_dirB_bm_eyes+\"/*.*\")), \"No binary mask found in \" + str(img_dirB_bm_eyes)\n",
    "    assert len(glob.glob(img_dirA_bm_eyes+\"/*.*\")) == len(train_A), \\\n",
    "    \"Number of faceA images does not match number of their binary masks. Can be caused by any none image file in the folder.\"\n",
    "    assert len(glob.glob(img_dirB_bm_eyes+\"/*.*\")) == len(train_B), \\\n",
    "    \"Number of faceB images does not match number of their binary masks. Can be caused by any none image file in the folder.\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_loss_config(loss_config):\n",
    "    for config, value in loss_config.items():\n",
    "        print(f\"{config} = {value}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display random binary masks of eyes\n",
    "train_batchA = DataLoader(train_A, train_AnB, batchSize, img_dirA_bm_eyes, \n",
    "                          RESOLUTION, num_cpus, K.get_session(), **da_config)\n",
    "train_batchB = DataLoader(train_B, train_AnB, batchSize, img_dirB_bm_eyes, \n",
    "                          RESOLUTION, num_cpus, K.get_session(), **da_config)\n",
    "_, tA, bmA = train_batchA.get_next_batch()\n",
    "_, tB, bmB = train_batchB.get_next_batch()\n",
    "showG_eyes(tA, tB, bmA, bmB, batchSize)\n",
    "del train_batchA, train_batchB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def reset_session(save_path):\n",
    "    global model, vggface\n",
    "    global train_batchA, train_batchB\n",
    "    model.save_weights(path=save_path)\n",
    "    del model\n",
    "    del vggface\n",
    "    del train_batchA\n",
    "    del train_batchB\n",
    "    K.clear_session()\n",
    "    model = FaceswapGANModel(**arch_config)\n",
    "    model.load_weights(path=save_path)\n",
    "    vggface = VGGFace(include_top=False, model='resnet50', input_shape=(224, 224, 3))\n",
    "    model.build_pl_model(vggface_model=vggface, before_activ=loss_config[\"PL_before_activ\"])\n",
    "    train_batchA = DataLoader(train_A, train_AnB, batchSize, img_dirA_bm_eyes,\n",
    "                              RESOLUTION, num_cpus, K.get_session(), **da_config)\n",
    "    train_batchB = DataLoader(train_B, train_AnB, batchSize, img_dirB_bm_eyes, \n",
    "                              RESOLUTION, num_cpus, K.get_session(), **da_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start training\n",
    "t0 = time.time()\n",
    "\n",
    "# This try/except is meant to resume training that was accidentally interrupted\n",
    "try:\n",
    "    gen_iterations\n",
    "    print(f\"Resume training from iter {gen_iterations}.\")\n",
    "except:\n",
    "    gen_iterations = 0\n",
    "\n",
    "errGA_sum = errGB_sum = errDA_sum = errDB_sum = 0\n",
    "errGAs = {}\n",
    "errGBs = {}\n",
    "# Dictionaries are ordered in Python 3.6\n",
    "for k in ['ttl', 'adv', 'recon', 'edge', 'pl']:\n",
    "    errGAs[k] = 0\n",
    "    errGBs[k] = 0\n",
    "\n",
    "display_iters = 300\n",
    "backup_iters = 5000\n",
    "TOTAL_ITERS = 40000\n",
    "\n",
    "global train_batchA, train_batchB\n",
    "train_batchA = DataLoader(train_A, train_AnB, batchSize, img_dirA_bm_eyes, \n",
    "                          RESOLUTION, num_cpus, K.get_session(), **da_config)\n",
    "train_batchB = DataLoader(train_B, train_AnB, batchSize, img_dirB_bm_eyes, \n",
    "                          RESOLUTION, num_cpus, K.get_session(), **da_config)\n",
    "\n",
    "while gen_iterations <= TOTAL_ITERS: \n",
    "    \n",
    "    # Loss function automation\n",
    "    if gen_iterations == (TOTAL_ITERS//5 - display_iters//2):\n",
    "        clear_output()\n",
    "        loss_config['use_PL'] = True\n",
    "        loss_config['use_mask_hinge_loss'] = False\n",
    "        loss_config['m_mask'] = 0.0\n",
    "        reset_session(models_dir)\n",
    "        print(\"Building new loss funcitons...\")\n",
    "        show_loss_config(loss_config)\n",
    "        model.build_train_functions(loss_weights=loss_weights, **loss_config)\n",
    "        print(\"Done.\")\n",
    "    elif gen_iterations == (TOTAL_ITERS//5 + TOTAL_ITERS//10 - display_iters//2):\n",
    "        clear_output()\n",
    "        loss_config['use_PL'] = True\n",
    "        loss_config['use_mask_hinge_loss'] = True\n",
    "        loss_config['m_mask'] = 0.5\n",
    "        reset_session(models_dir)\n",
    "        print(\"Building new loss funcitons...\")\n",
    "        show_loss_config(loss_config)\n",
    "        model.build_train_functions(loss_weights=loss_weights, **loss_config)\n",
    "        print(\"Complete.\")\n",
    "    elif gen_iterations == (2*TOTAL_ITERS//5 - display_iters//2):\n",
    "        clear_output()\n",
    "        loss_config['use_PL'] = True\n",
    "        loss_config['use_mask_hinge_loss'] = True\n",
    "        loss_config['m_mask'] = 0.2\n",
    "        reset_session(models_dir)\n",
    "        print(\"Building new loss funcitons...\")\n",
    "        show_loss_config(loss_config)\n",
    "        model.build_train_functions(loss_weights=loss_weights, **loss_config)\n",
    "        print(\"Done.\")\n",
    "    elif gen_iterations == (TOTAL_ITERS//2 - display_iters//2):\n",
    "        clear_output()\n",
    "        loss_config['use_PL'] = True\n",
    "        loss_config['use_mask_hinge_loss'] = True\n",
    "        loss_config['m_mask'] = 0.4\n",
    "        reset_session(models_dir)\n",
    "        print(\"Building new loss funcitons...\")\n",
    "        show_loss_config(loss_config)\n",
    "        model.build_train_functions(loss_weights=loss_weights, **loss_config)\n",
    "        print(\"Done.\")\n",
    "    elif gen_iterations == (2*TOTAL_ITERS//3 - display_iters//2):\n",
    "        clear_output()\n",
    "        loss_config['use_PL'] = True\n",
    "        loss_config['use_mask_hinge_loss'] = False\n",
    "        loss_config['m_mask'] = 0.\n",
    "        loss_config['lr_factor'] = 0.3\n",
    "        reset_session(models_dir)\n",
    "        print(\"Building new loss funcitons...\")\n",
    "        show_loss_config(loss_config)\n",
    "        model.build_train_functions(loss_weights=loss_weights, **loss_config)\n",
    "        print(\"Done.\")\n",
    "    elif gen_iterations == (8*TOTAL_ITERS//10 - display_iters//2):\n",
    "        clear_output()\n",
    "        model.decoder_A.load_weights(\"models/decoder_B.h5\") # swap decoders\n",
    "        model.decoder_B.load_weights(\"models/decoder_A.h5\") # swap decoders\n",
    "        loss_config['use_PL'] = True\n",
    "        loss_config['use_mask_hinge_loss'] = True\n",
    "        loss_config['m_mask'] = 0.1\n",
    "        loss_config['lr_factor'] = 0.3\n",
    "        reset_session(models_dir)\n",
    "        print(\"Building new loss funcitons...\")\n",
    "        show_loss_config(loss_config)\n",
    "        model.build_train_functions(loss_weights=loss_weights, **loss_config)\n",
    "        print(\"Done.\")\n",
    "    elif gen_iterations == (9*TOTAL_ITERS//10 - display_iters//2):\n",
    "        clear_output()\n",
    "        loss_config['use_PL'] = True\n",
    "        loss_config['use_mask_hinge_loss'] = False\n",
    "        loss_config['m_mask'] = 0.0\n",
    "        loss_config['lr_factor'] = 0.1\n",
    "        reset_session(models_dir)\n",
    "        print(\"Building new loss funcitons...\")\n",
    "        show_loss_config(loss_config)\n",
    "        model.build_train_functions(loss_weights=loss_weights, **loss_config)\n",
    "        print(\"Done.\")\n",
    "    \n",
    "    if gen_iterations == 5:\n",
    "        print (\"working.\")\n",
    "    \n",
    "    # Train dicriminators for one batch\n",
    "    data_A = train_batchA.get_next_batch()\n",
    "    data_B = train_batchB.get_next_batch()\n",
    "    errDA, errDB = model.train_one_batch_D(data_A=data_A, data_B=data_B)\n",
    "    errDA_sum +=errDA[0]\n",
    "    errDB_sum +=errDB[0]\n",
    "\n",
    "    # Train generators for one batch\n",
    "    data_A = train_batchA.get_next_batch()\n",
    "    data_B = train_batchB.get_next_batch()\n",
    "    errGA, errGB = model.train_one_batch_G(data_A=data_A, data_B=data_B)\n",
    "    errGA_sum += errGA[0]\n",
    "    errGB_sum += errGB[0]\n",
    "    for i, k in enumerate(['ttl', 'adv', 'recon', 'edge', 'pl']):\n",
    "        errGAs[k] += errGA[i]\n",
    "        errGBs[k] += errGB[i]\n",
    "    gen_iterations+=1\n",
    "    \n",
    "    # Visualization\n",
    "    if gen_iterations % display_iters == 0:\n",
    "        clear_output()\n",
    "            \n",
    "        # Display loss information\n",
    "        show_loss_config(loss_config)\n",
    "        print(\"----------\") \n",
    "        print('[iter %d] Loss_DA: %f Loss_DB: %f Loss_GA: %f Loss_GB: %f time: %f'\n",
    "        % (gen_iterations, errDA_sum/display_iters, errDB_sum/display_iters,\n",
    "           errGA_sum/display_iters, errGB_sum/display_iters, time.time()-t0))  \n",
    "        print(\"----------\") \n",
    "        print(\"Generator loss details:\")\n",
    "        print(f'[Adversarial loss]')  \n",
    "        print(f'GA: {errGAs[\"adv\"]/display_iters:.4f} GB: {errGBs[\"adv\"]/display_iters:.4f}')\n",
    "        print(f'[Reconstruction loss]')\n",
    "        print(f'GA: {errGAs[\"recon\"]/display_iters:.4f} GB: {errGBs[\"recon\"]/display_iters:.4f}')\n",
    "        print(f'[Edge loss]')\n",
    "        print(f'GA: {errGAs[\"edge\"]/display_iters:.4f} GB: {errGBs[\"edge\"]/display_iters:.4f}')\n",
    "        if loss_config['use_PL'] == True:\n",
    "            print(f'[Perceptual loss]')\n",
    "            try:\n",
    "                print(f'GA: {errGAs[\"pl\"][0]/display_iters:.4f} GB: {errGBs[\"pl\"][0]/display_iters:.4f}')\n",
    "            except:\n",
    "                print(f'GA: {errGAs[\"pl\"]/display_iters:.4f} GB: {errGBs[\"pl\"]/display_iters:.4f}')\n",
    "        \n",
    "        # Display images\n",
    "        print(\"----------\") \n",
    "        wA, tA, _ = train_batchA.get_next_batch()\n",
    "        wB, tB, _ = train_batchB.get_next_batch()\n",
    "        print(\"Transformed (masked) results:\")\n",
    "        showG(tA, tB, model.path_A, model.path_B, batchSize)   \n",
    "        print(\"Masks:\")\n",
    "        showG_mask(tA, tB, model.path_mask_A, model.path_mask_B, batchSize)  \n",
    "        print(\"Reconstruction results:\")\n",
    "        showG(wA, wB, model.path_bgr_A, model.path_bgr_B, batchSize)           \n",
    "        errGA_sum = errGB_sum = errDA_sum = errDB_sum = 0\n",
    "        for k in ['ttl', 'adv', 'recon', 'edge', 'pl']:\n",
    "            errGAs[k] = 0\n",
    "            errGBs[k] = 0\n",
    "        \n",
    "        # Save models\n",
    "        model.save_weights(path=models_dir)\n",
    "    \n",
    "    # Backup models\n",
    "    if gen_iterations % backup_iters == 0: \n",
    "        bkup_dir = f\"{models_dir}/backup_iter{gen_iterations}\"\n",
    "        Path(bkup_dir).mkdir(parents=True, exist_ok=True)\n",
    "        model.save_weights(path=bkup_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display random results\n",
    "wA, tA, _ = train_batchA.get_next_batch()\n",
    "wB, tB, _ = train_batchB.get_next_batch()\n",
    "print(\"Transformed (masked) results:\")\n",
    "showG(tA, tB, model.path_A, model.path_B, batchSize)   \n",
    "print(\"Masks:\")\n",
    "showG_mask(tA, tB, model.path_mask_A, model.path_mask_B, batchSize)  \n",
    "print(\"Reconstruction results:\")\n",
    "showG(wA, wB, model.path_bgr_A, model.path_bgr_B, batchSize) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# (Optional) Additional 40k iterations of training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "loss_config['use_PL'] = True\n",
    "loss_config['use_mask_hinge_loss'] = False\n",
    "loss_config['m_mask'] = 0.0\n",
    "loss_config['lr_factor'] = 0.1\n",
    "reset_session(models_dir)\n",
    "print(\"Building new loss funcitons...\")\n",
    "show_loss_config(loss_config)\n",
    "model.build_train_functions(loss_weights=loss_weights, **loss_config)\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "# Start training\n",
    "t0 = time.time()\n",
    "gen_iterations = 0\n",
    "errGA_sum = errGB_sum = errDA_sum = errDB_sum = 0\n",
    "errGAs = {}\n",
    "errGBs = {}\n",
    "# Dictionaries are ordered in Python 3.6\n",
    "for k in ['ttl', 'adv', 'recon', 'edge', 'pl']:\n",
    "    errGAs[k] = 0\n",
    "    errGBs[k] = 0\n",
    "\n",
    "display_iters = 300\n",
    "backup_iters = 5000\n",
    "TOTAL_ITERS = 40000\n",
    "\n",
    "global train_batchA, train_batchB\n",
    "train_batchA = DataLoader(train_A, train_AnB, batchSize, img_dirA_bm_eyes, \n",
    "                          RESOLUTION, num_cpus, K.get_session(), **da_config)\n",
    "train_batchB = DataLoader(train_B, train_AnB, batchSize, img_dirB_bm_eyes, \n",
    "                          RESOLUTION, num_cpus, K.get_session(), **da_config)\n",
    "\n",
    "while gen_iterations <= TOTAL_ITERS: \n",
    "    \n",
    "    if gen_iterations == 5:\n",
    "        print (\"working.\")\n",
    "    \n",
    "    # Train dicriminators for one batch\n",
    "    data_A = train_batchA.get_next_batch()\n",
    "    data_B = train_batchB.get_next_batch()\n",
    "    errDA, errDB = model.train_one_batch_D(data_A=data_A, data_B=data_B)\n",
    "    errDA_sum +=errDA[0]\n",
    "    errDB_sum +=errDB[0]\n",
    "\n",
    "    # Train generators for one batch\n",
    "    data_A = train_batchA.get_next_batch()\n",
    "    data_B = train_batchB.get_next_batch()\n",
    "    errGA, errGB = model.train_one_batch_G(data_A=data_A, data_B=data_B)\n",
    "    errGA_sum += errGA[0]\n",
    "    errGB_sum += errGB[0]\n",
    "    for i, k in enumerate(['ttl', 'adv', 'recon', 'edge', 'pl']):\n",
    "        errGAs[k] += errGA[i]\n",
    "        errGBs[k] += errGB[i]\n",
    "    gen_iterations+=1\n",
    "    \n",
    "    # Visualization\n",
    "    if gen_iterations % display_iters == 0:\n",
    "        clear_output()\n",
    "            \n",
    "        # Display loss information\n",
    "        show_loss_config(loss_config)\n",
    "        print(\"----------\") \n",
    "        print('[iter %d] Loss_DA: %f Loss_DB: %f Loss_GA: %f Loss_GB: %f time: %f'\n",
    "        % (gen_iterations, errDA_sum/display_iters, errDB_sum/display_iters,\n",
    "           errGA_sum/display_iters, errGB_sum/display_iters, time.time()-t0))  \n",
    "        print(\"----------\") \n",
    "        print(\"Generator loss details:\")\n",
    "        print(f'[Adversarial loss]')  \n",
    "        print(f'GA: {errGAs[\"adv\"]/display_iters:.4f} GB: {errGBs[\"adv\"]/display_iters:.4f}')\n",
    "        print(f'[Reconstruction loss]')\n",
    "        print(f'GA: {errGAs[\"recon\"]/display_iters:.4f} GB: {errGBs[\"recon\"]/display_iters:.4f}')\n",
    "        print(f'[Edge loss]')\n",
    "        print(f'GA: {errGAs[\"edge\"]/display_iters:.4f} GB: {errGBs[\"edge\"]/display_iters:.4f}')\n",
    "        if loss_config['use_PL'] == True:\n",
    "            print(f'[Perceptual loss]')\n",
    "            try:\n",
    "                print(f'GA: {errGAs[\"pl\"][0]/display_iters:.4f} GB: {errGBs[\"pl\"][0]/display_iters:.4f}')\n",
    "            except:\n",
    "                print(f'GA: {errGAs[\"pl\"]/display_iters:.4f} GB: {errGBs[\"pl\"]/display_iters:.4f}')\n",
    "        \n",
    "        # Display images\n",
    "        print(\"----------\") \n",
    "        wA, tA, _ = train_batchA.get_next_batch()\n",
    "        wB, tB, _ = train_batchB.get_next_batch()\n",
    "        print(\"Transformed (masked) results:\")\n",
    "        showG(tA, tB, model.path_A, model.path_B, batchSize)   \n",
    "        print(\"Masks:\")\n",
    "        showG_mask(tA, tB, model.path_mask_A, model.path_mask_B, batchSize)  \n",
    "        print(\"Reconstruction results:\")\n",
    "        showG(wA, wB, model.path_bgr_A, model.path_bgr_B, batchSize)           \n",
    "        errGA_sum = errGB_sum = errDA_sum = errDB_sum = 0\n",
    "        for k in ['ttl', 'adv', 'recon', 'edge', 'pl']:\n",
    "            errGAs[k] = 0\n",
    "            errGBs[k] = 0\n",
    "        \n",
    "        # Save models\n",
    "        model.save_weights(path=models_dir)\n",
    "    \n",
    "    # Backup models\n",
    "    if gen_iterations % backup_iters == 0: \n",
    "        bkup_dir = f\"{models_dir}/backup_iter{gen_iterations}\"\n",
    "        Path(bkup_dir).mkdir(parents=True, exist_ok=True)\n",
    "        model.save_weights(path=bkup_dir)\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id='tf'></a>\n",
    "# Single Image Transformation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from detector.face_detector import MTCNNFaceDetector\n",
    "from converter.landmarks_alignment import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "mtcnn_weights_dir = \"./mtcnn_weights/\"\n",
    "fd = MTCNNFaceDetector(sess=K.get_session(), model_path=mtcnn_weights_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "from converter.face_transformer import FaceTransformer\n",
    "ftrans = FaceTransformer()\n",
    "ftrans.set_model(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read input image\n",
    "input_img = plt.imread(\"./TEST_IMAGE.jpg\")[...,:3]\n",
    "\n",
    "if input_img.dtype == np.float32:\n",
    "    print(\"input_img has dtype np.float32 (perhaps the image format is PNG). Scale it to uint8.\")\n",
    "    input_img = (input_img * 255).astype(np.uint8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display input image\n",
    "plt.imshow(input_img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display detected face\n",
    "face, lms = fd.detect_face(input_img)\n",
    "if len(face) == 1:\n",
    "    x0, y1, x1, y0, _ = face[0]\n",
    "    det_face_im = input_img[int(x0):int(x1),int(y0):int(y1),:]\n",
    "    try:\n",
    "        src_landmarks = get_src_landmarks(x0, x1, y0, y1, lms)\n",
    "        tar_landmarks = get_tar_landmarks(det_face_im)\n",
    "        aligned_det_face_im = landmarks_match_mtcnn(det_face_im, src_landmarks, tar_landmarks)\n",
    "    except:\n",
    "        print(\"An error occured during face alignment.\")\n",
    "        aligned_det_face_im = det_face_im\n",
    "elif len(face) == 0:\n",
    "    raise ValueError(\"Error: no face detected.\")\n",
    "elif len(face) > 1:\n",
    "    print (face)\n",
    "    raise ValueError(\"Error: multiple faces detected\")\n",
    "    \n",
    "plt.imshow(aligned_det_face_im)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Transform detected face\n",
    "result_img, result_rgb, result_mask = ftrans.transform(\n",
    "                    aligned_det_face_im, \n",
    "                    direction=\"AtoB\", \n",
    "                    roi_coverage=0.93,\n",
    "                    color_correction=\"adain_xyz\",\n",
    "                    IMAGE_SHAPE=(RESOLUTION, RESOLUTION, 3)\n",
    "                    )\n",
    "try:\n",
    "    result_img = landmarks_match_mtcnn(result_img, tar_landmarks, src_landmarks)\n",
    "    result_rgb = landmarks_match_mtcnn(result_rgb, tar_landmarks, src_landmarks)\n",
    "    result_mask = landmarks_match_mtcnn(result_mask, tar_landmarks, src_landmarks)\n",
    "except:\n",
    "    print(\"An error occured during face alignment.\")\n",
    "    pass\n",
    "\n",
    "result_input_img = input_img.copy()\n",
    "result_input_img[int(x0):int(x1),int(y0):int(y1),:] = result_mask.astype(np.float32)/255*result_rgb +\\\n",
    "(1-result_mask.astype(np.float32)/255)*result_input_img[int(x0):int(x1),int(y0):int(y1),:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show result face\n",
    "plt.imshow(result_input_img[int(x0):int(x1),int(y0):int(y1),:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show transformed image before masking\n",
    "plt.imshow(result_rgb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show alpha mask\n",
    "plt.imshow(result_mask[..., 0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display interpolations before/after transformation\n",
    "def interpolate_imgs(im1, im2):\n",
    "    im1, im2 = map(np.float32, [im1, im2])\n",
    "    out = [ratio * im1 + (1-ratio) * im2 for ratio in np.linspace(1, 0, 5)]\n",
    "    out = map(np.uint8, out)\n",
    "    return out\n",
    "\n",
    "plt.figure(figsize=(15,8))\n",
    "plt.imshow(np.hstack(interpolate_imgs(input_img, result_input_img)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
